import json
import os
from tqdm import tqdm
from collections import defaultdict
from azfuse import File
import random

import mimetypes
import os
from io import BytesIO
from typing import Union
import cv2
import base64
import time
import argparse
import requests
import torch
import random
from tqdm import tqdm
# import transformers
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from tqdm import tqdm
import sys
# from otter.modeling_otter import OtterForConditionalGeneration

import json

# PROMPT = """
# You are given a pair of very similar images. In image 2, there is a specific object that is missing or changed from image 1. Generate 3 questions about that object.

# There are a few rules to follow for each question:
# 1. The question should be answerable for image 1, that is there is a definitive answer to the question, just by looking at image 1. 
# 2. The question should not be answerable for image 2. "Not answerable" means, just by looking at image 2, the answer would be something like "I don't know" or "I don't see it".
# 3. The question should be relevant to the content of each image alone, even without seeing the other image. 

# The response should be formatted as:
# - Q1: <question>
#   A1: <answer for image 1>
# - Q2: <question>
#   A2: <answer for image 1>
# - Q3: <question>
#   A3: <answer for image 1>
# """


# Default instruction text for the paired-image query: test_single() uses it as
# its default `prompt`, and gptv_query_paired_image() sends it as the first
# content item of the user message, ahead of the two images.
PROMPT = """
You are given a pair of very similar images. In image 2, there is a specific object that is missing or changed from image 1. Generate a question that is answerable for image 1 while not answerable for image 2.

There are a few rules to follow for each question:
1. The question should be answerable for image 1, that is there is a definitive answer to the question, just by looking at image 1. 
2. The question should not be answerable for image 2. "Not answerable" means, just by looking at image 2, the answer would be something like "I don't know" or "I don't see it".
3. The question should be relevant to the content of each image alone, even without seeing the other image. 

The response should be formatted as:
- Q: <question>
- A1: <answer for image 1>
- A2: <answer for image 2>
"""

# Silence urllib3 warnings process-wide. get_image() issues requests with
# verify=False; presumably this is here to hide the warnings those calls emit.
requests.packages.urllib3.disable_warnings()

# ------------------- Utility Functions -------------------


def get_content_type(file_path):
    """Guess the MIME type of *file_path* from its name.

    Returns the type string reported by ``mimetypes.guess_type`` (for
    example ``"image/jpeg"``), or ``None`` when no type can be guessed.
    The encoding half of the guess is discarded.
    """
    guessed = mimetypes.guess_type(file_path)
    return guessed[0]


# ------------------- Image and Video Handling Functions -------------------


def get_image(url: str) -> Union[Image.Image, list]:
    """Open an image from a local path or a remote URL.

    A string without ``"://"`` is treated as a local file and its type is
    guessed from the file name; anything else is treated as a URL and its
    type is taken from the ``Content-Type`` header of a HEAD request.

    Args:
        url: Local file path or remote URL.

    Returns:
        The opened ``PIL.Image.Image``.

    Raises:
        ValueError: If the content type is unknown or is not an image type.
    """
    is_remote = "://" in url
    if is_remote:
        content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")
    else:
        content_type = get_content_type(url)

    # Both lookups may yield None (unknown extension / missing header); treat
    # that as "not an image" instead of failing with a TypeError on `in`.
    if content_type is None or "image" not in content_type:
        raise ValueError(f"Invalid content type: {content_type!r}. Expected image.")

    if is_remote:
        return Image.open(requests.get(url, stream=True, verify=False).raw)
    return Image.open(url)


def gptv_query_paired_image(orig_img, modified_img, prompt, temp=0., max_refusal_retries=3):
    """Send a prompt plus two images to the GPT-4V endpoint and return its reply.

    The request is POSTed to the URL in the ``GPTV_API_BASE`` environment
    variable, with the ``api-key`` header taken from ``GPT4V_KEY``. The user
    message contains, in order: the prompt, the label ``'Image 1:'``, the
    base64-encoded original image, the label ``'Image 2:'`` and the
    base64-encoded modified image.

    Transport errors, non-200 responses and responses without ``choices`` are
    retried indefinitely with a 60 second pause; each non-200 response also
    raises the sampling temperature by 0.1 (capped at 1.0). A reply that
    starts with "I'm sorry" is treated as a refusal and retried at most
    ``max_refusal_retries`` times, after which the last reply is returned.

    Args:
        orig_img: Path to the original image file.
        modified_img: Path to the modified image file.
        prompt: Instruction text placed before the images.
        temp: Initial sampling temperature.
        max_refusal_retries: How many refusal replies to retry before giving
            up and returning the refusal text.

    Returns:
        The ``message.content`` string of the first choice in the response.

    Raises:
        AssertionError: If either path is ``None`` or does not exist.
        KeyError: If ``GPTV_API_BASE`` is not set.
    """
    headers = {
        "Content-Type": "application/json",
        "api-key": os.environ.get("GPT4V_KEY"),
        'cogsvc-openai-gptv-disable-faceblur': 'true',
    }

    assert orig_img is not None and modified_img is not None
    assert os.path.exists(orig_img) and os.path.exists(modified_img)

    # Resolve the endpoint before entering the retry loop: a missing variable
    # is a configuration error and must fail immediately rather than being
    # swallowed by the broad except below and retried forever.
    api_base = os.environ["GPTV_API_BASE"]

    content = [prompt]
    for label, path in (('Image 1:', orig_img), ('Image 2:', modified_img)):
        with open(path, "rb") as img_f:
            encoded = base64.b64encode(img_f.read()).decode()
        content.append(label)
        content.append({'image': encoded})

    data = {
        'max_tokens': 2048,
        'temperature': temp,
        'top_p': 0.5,
        'messages': [{"role": "user", "content": content}],
    }

    response_text, response_json, regular_time = '', None, 60
    refusals = 0
    while len(response_text) < 2:
        try:
            # Re-serialise on every attempt because the temperature may change.
            response = requests.post(api_base, headers=headers, data=json.dumps(data))
            response_json = response.json()
        except Exception as e:
            # Deliberately broad: this is a long-running best-effort batch
            # call, so any transport/decoding failure is logged and retried.
            print(e)
            time.sleep(regular_time)
            continue
        if response.status_code != 200:
            print(response.headers, response.content)
            print(orig_img, modified_img)
            print(f"The response status code for is {response.status_code} (Not OK)")
            time.sleep(regular_time)
            data['temperature'] = min(data['temperature'] + 0.1, 1.0)
            continue
        if 'choices' not in response_json:
            time.sleep(regular_time)
            continue
        response_text = response_json["choices"][0]["message"]["content"]
        is_refusal = response_text.strip().lower().startswith("i'm sorry")
        if is_refusal and refusals < max_refusal_retries:
            # Clear the text so the loop condition actually triggers a retry.
            refusals += 1
            response_text = ''
            time.sleep(regular_time)
    return response_text

# ------------------- Main Function -------------------

def test_single(prompt=PROMPT, orig_img=None, mask_img=None, temp=0):
    """Query the model once for an (original, masked) image pair.

    Forwards ``orig_img`` and ``mask_img`` as the first and second image to
    ``gptv_query_paired_image`` together with ``prompt`` and ``temp``, and
    returns whatever that call returns.
    """
    return gptv_query_paired_image(orig_img, mask_img, prompt, temp=temp)


def main(mask_img_folder, annotation_file, output_folder, orig_img_folder, debug=False, temp=0.7, overwrite=False):
    """Generate one model response per (image, object-name group) pair.

    For every record of the JSONL ``annotation_file``:

    1. The original image ``{vg_id}.jpg`` from ``orig_img_folder`` is copied
       into ``output_folder``.
    2. Scene-graph objects are grouped by name, skipping objects that have no
       relations or no attributes and objects named ``"background"``.
    3. For each group, the mask image
       ``{vg_id}.{sorted ids joined by '-'}_remove_0.png`` is looked up in
       ``mask_img_folder``; if present it is copied into ``output_folder``,
       the pair is sent to ``test_single`` and the response is written to
       ``{vg_id}.{ids}.txt``.

    Args:
        mask_img_folder: Folder holding the ``*_remove_0.png`` images.
        annotation_file: JSONL file; each record needs ``vg_id`` and
            ``scene_graph``.
        output_folder: Destination folder (created if missing).
        orig_img_folder: Folder holding the original ``{vg_id}.jpg`` images.
        debug: If true, only the first 20 records are processed.
        temp: Sampling temperature forwarded to ``test_single``.
        overwrite: If true, image copies are redone even when the destination
            already exists.

    Raises:
        AssertionError: If an original image is missing.
    """
    from collections import defaultdict
    import shutil

    os.makedirs(output_folder, exist_ok=True)
    with open(annotation_file, "r") as ann_f:
        # Skip blank lines (e.g. a trailing newline) instead of failing to parse them.
        data = [json.loads(line) for line in ann_f if line.strip()]
    if debug:
        data = data[:20]

    for d in tqdm(data):
        file_id = d["vg_id"]
        orig_img = os.path.join(orig_img_folder, f'{file_id}.jpg')
        assert os.path.exists(orig_img)
        orig_copy = os.path.join(output_folder, os.path.basename(orig_img))
        if overwrite or not os.path.exists(orig_copy):
            shutil.copyfile(orig_img, orig_copy)

        # Collect the ids of objects that share the same name.
        name2obj = defaultdict(list)
        for obj_id, item in d["scene_graph"].items():
            if len(item["relations"]) == 0 or len(item["attributes"]) == 0:
                continue
            name = item["name"]
            if name in ["background"]:
                continue
            name2obj[name].append(obj_id)

        for obj_ids in tqdm(name2obj.values(), desc=f"Processing {file_id}"):
            output_filename = f'{file_id}.{"-".join(sorted(obj_ids))}'
            output_file = os.path.join(output_folder, f"{output_filename}.txt")
            mask_img_file = os.path.join(mask_img_folder, f'{output_filename}_remove_0.png')
            if not os.path.exists(mask_img_file):
                print(f"Mask file {mask_img_file} does not exist")
                continue
            # The existence check must look at the same path that is written;
            # previously it tested a ".png" name, so it never saw the copy.
            # NOTE(review): the destination keeps the historical ".jpg" name
            # although the source is a PNG file - confirm which extension
            # downstream consumers expect.
            mask_copy = os.path.join(output_folder, f'{output_filename}_remove_0.jpg')
            if overwrite or not os.path.exists(mask_copy):
                shutil.copyfile(mask_img_file, mask_copy)
            response = test_single(orig_img=orig_img, mask_img=mask_img_file, temp=temp)
            with open(output_file, "w") as out_f:
                out_f.write(response)



def match_to_question(annotation_file, image_folder, raw_data, output_file, debug=False, remove_multi_obj=True):
    """Pair existing questions with perturbed images and "I don't know" answers.

    For every record of the JSONL ``annotation_file`` the scene-graph objects
    are grouped by name (skipping objects with no relations or no attributes
    and objects named ``"background"``). Each question of the record is then
    matched against the objects referenced by its entry in ``raw_data``
    (``raw_data[qid]["annotations"]["question"]``). A question is written to
    ``output_file`` (one JSON object per line) when at least one referenced
    object name occurs in the question text and the corresponding perturbed
    image ``{vg_id}.{sorted ids joined by '-'}_remove_0.png`` exists in
    ``image_folder``.

    Questions starting with "are there", "is there" or "do you see" are
    skipped. The perturbed answer is drawn at random from a fixed list of
    templates; for questions starting with "who" it is replaced by
    ``"No one."`` with probability 0.5.

    Args:
        annotation_file: JSONL file; each record needs ``vg_id``,
            ``scene_graph``, ``question_ids``, ``questions`` and ``answers``.
        image_folder: Folder that holds the perturbed images.
        raw_data: Path to a JSON file mapping question id to question record.
        output_file: Path of the JSONL file to write.
        debug: If true, only the first 100 records are processed.
        remove_multi_obj: If true, questions that reference more than one
            distinct object name are skipped.
    """
    with open(annotation_file, "r") as ann_f:
        # Skip blank lines (e.g. a trailing newline) instead of failing to parse them.
        scene_graph_data = [json.loads(line) for line in ann_f if line.strip()]
    with open(raw_data, "r") as raw_f:
        raw_data = json.load(raw_f)
    if debug:
        scene_graph_data = scene_graph_data[:100]

    variants_of_idk_answers_for_q = [
        "I don't know.",
        "I don't see any {}.",
        "There is no {} in the image.",
        "I can't see any {}.",
    ]
    existence_prefixes = ("are there", "is there", "do you see")

    with open(output_file, "w") as out_f:
        for d in tqdm(scene_graph_data):
            file_id = d["vg_id"]
            scene_graph = d["scene_graph"]
            name2obj = defaultdict(list)
            for obj_id, item in scene_graph.items():
                if len(item["relations"]) == 0 or len(item["attributes"]) == 0:
                    continue
                name = item["name"]
                if name in ["background"]:
                    continue
                name2obj[name].append(obj_id)

            for q, ans, qid in zip(d["questions"], d["answers"], d["question_ids"]):
                q_lower = q.lower()
                if q_lower.startswith(existence_prefixes):
                    continue
                annotations = raw_data[qid]["annotations"]['question']
                object_names = list({
                    scene_graph[oid]["name"]
                    for oid in annotations.values()
                    if oid in scene_graph
                })
                if len(object_names) > 1 and remove_multi_obj:
                    continue

                images = []
                idk_answers = []
                for name in object_names:
                    if name not in name2obj or name.lower() not in q_lower:
                        continue
                    # The random draws happen before the file-existence check,
                    # matching the historical order so seeded runs consume the
                    # random stream the same way. Formatting the
                    # placeholder-free "I don't know." template is a no-op.
                    perturbed_answer = random.choice(variants_of_idk_answers_for_q).format(name)
                    if q_lower.startswith("who") and random.random() > 0.5:
                        perturbed_answer = "No one."
                    # Ids are sorted so the file name agrees with the naming
                    # used by main() for the same images.
                    obj_ids = sorted(name2obj[name])
                    perturbed_image_path = os.path.join(
                        image_folder, f'{file_id}.{"-".join(obj_ids)}_remove_0.png'
                    )
                    if not os.path.exists(perturbed_image_path):
                        continue
                    images.append(os.path.basename(perturbed_image_path))
                    idk_answers.append(perturbed_answer)
                if not images:
                    continue

                out = {
                    "question_id": qid,
                    "image_id": file_id,
                    "question": q,
                    "answer": ans,
                    "perturbed_image": images,
                    "perturbed_answer": idk_answers,
                }
                out_f.write(json.dumps(out) + "\n")


def sample_subset(jsonl_file, num_samples=2500):
    """Write the first ``num_samples`` records of a JSONL file to a new file.

    Despite the name, no random sampling takes place: the leading records are
    kept in their original order. The output path is derived from
    ``jsonl_file`` by replacing ``".jsonl"`` with
    ``"_sample_{num_samples}.jsonl"``.

    Args:
        jsonl_file: Path of the input JSONL file (opened through ``File``).
        num_samples: Number of leading records to keep.
    """
    with File.open(jsonl_file, "r") as src:
        data = [json.loads(line.strip()) for line in src]
    data = data[:num_samples]

    output_file = jsonl_file.replace(".jsonl", f"_sample_{num_samples}.jsonl")
    if output_file == jsonl_file:
        # The path contains no ".jsonl", so the replacement did nothing and
        # writing would overwrite the input. Insert the suffix before the
        # extension instead.
        root, ext = os.path.splitext(jsonl_file)
        output_file = f"{root}_sample_{num_samples}{ext}"

    with File.open(output_file, "w") as dst:
        for d in data:
            dst.write(json.dumps(d) + "\n")
    

def convert_to_llava_format(jsonl_file, output_json, image_folder="data", split="train"):
    """Convert matched question records into a list of conversation samples.

    Each input record yields one "clean" sample (original image, original
    answer) followed by one sample per perturbed image (perturbed image,
    perturbed answer). The result is written to ``output_json`` as an
    indented JSON list.

    Args:
        jsonl_file: JSONL file whose records carry ``question_id``,
            ``image_id``, ``question``, ``answer``, ``perturbed_image`` and
            ``perturbed_answer``.
        output_json: Path of the JSON file to write.
        image_folder: Root folder used only to check that each clean image
            exists; it is not part of the stored image paths.
        split: Used to build the perturbed-image folder name
            ``gqa/{split}_lama_box``.

    Raises:
        AssertionError: If a clean image is missing under ``image_folder``.
    """
    def make_sample(sample_id, image_path, question, answer):
        # One human turn (image token + question) and one model turn.
        return {
            "id": sample_id,
            "image": image_path,
            "conversations": [
                {"from": "human", "value": "<image>\n" + question},
                {"from": "gpt", "value": answer},
            ],
        }

    with File.open(jsonl_file, "r") as src:
        data = [json.loads(line.strip()) for line in src]

    clean_image_folder = "gqa/images/"
    perturb_image_folder = f"gqa/{split}_lama_box"
    out_data = []
    for d in data:
        clean_image_path = os.path.join(clean_image_folder, f"{d['image_id']}.jpg")
        assert File.isfile(os.path.join(image_folder, clean_image_path)), f"Missing {clean_image_path}"
        question_id = str(d["question_id"])
        out_data.append(make_sample(question_id + "_clean", clean_image_path, d["question"], d["answer"]))
        for p_img, p_ans in zip(d['perturbed_image'], d['perturbed_answer']):
            # Existence of the perturbed image is intentionally not checked here.
            perturb_image_path = os.path.join(perturb_image_folder, p_img)
            out_data.append(make_sample(question_id + "_" + p_img, perturb_image_path, d["question"], p_ans))

    with File.open(output_json, "w") as dst:
        json.dump(out_data, dst, indent=2)


def merge_with_llava_data(llava_json, pertubed_gqa_json, output_json):
    """Concatenate two JSON sample lists and write the result.

    The contents of ``llava_json`` come first, followed by the contents of
    ``pertubed_gqa_json``; the combined list is written to ``output_json``
    as indented JSON. Input files are closed as soon as they are parsed.
    """
    with File.open(llava_json, "r") as llava_f:
        llava_data = json.load(llava_f)
    with File.open(pertubed_gqa_json, "r") as gqa_f:
        gqa_data = json.load(gqa_f)

    with File.open(output_json, "w") as dst:
        json.dump(llava_data + gqa_data, dst, indent=2)


def merge_with_our_unk_data(unk_json, pertubed_gqa_json, output_json):
    """Prefix image paths of the "unk" samples and merge them with GQA samples.

    Every sample of ``unk_json`` whose ``image`` value does not already
    contain ``"lama-gpt4v_gen_q"`` gets that folder prepended. The adjusted
    samples are followed by the samples of ``pertubed_gqa_json`` and the
    combined list is written to ``output_json`` as indented JSON. Input files
    are closed as soon as they are parsed.
    """
    image_folder = "lama-gpt4v_gen_q"

    with File.open(unk_json, "r") as unk_f:
        unk_data = json.load(unk_f)
    for d in unk_data:
        # Substring test, so already-prefixed paths are left untouched.
        if image_folder not in d["image"]:
            d["image"] = os.path.join(image_folder, d["image"])

    with File.open(pertubed_gqa_json, "r") as gqa_f:
        gqa_data = json.load(gqa_f)

    with File.open(output_json, "w") as dst:
        json.dump(unk_data + gqa_data, dst, indent=2)


if __name__ == "__main__":
    # Command-line entry point. Fire() is called without an explicit
    # component; presumably this exposes the functions defined in this module
    # as sub-commands - confirm against the Python Fire documentation.
    from fire import Fire
    Fire()


                      
